import copy

import torch

from Network.General.Conv.conv import ConvNetwork
from Network.General.Factor.factor_utils import final_conv_args, final_mlp_args
from Network.General.Factor.factored import return_values
from Network.General.Flat.mlp import MLPNetwork
from Network.General.GNN.gnn import GraphNetwork
from Network.network import Network
from Network.network_utils import reduce_function


class GraphPairNetwork(Network):
    """Factored network that runs one GNN pass over a per-sample key/query graph.

    The graph holds ``num_obj + 1`` nodes per sample. Node ``b * (num_obj + 1) + j``
    (``j < num_obj``) is query object ``j`` of sample ``b``, and node
    ``b * (num_obj + 1) + num_obj`` is that sample's key node. (Presumably
    ``GraphNetwork`` lays out ``(query, key)`` in this order -- TODO confirm.)

    Edges (as ``(row, col)`` pairs; source/target semantics are decided by
    ``GraphNetwork``):

    * every query node  -> the key node of its sample,
    * the key node      -> every query node except the last one of the sample,
    * a self loop on every query node except the last one of the sample.

    Nodes whose mask entry is 0 lose all of their incident edges.
    """

    def __init__(self, args):
        """Build the GNN layer and, depending on ``args``, a decoding head.

        Args:
            args: hierarchical config. Reads ``args.factor``,
                ``args.factor_net.{reduce_function, num_pair_layers, append_keys}``,
                ``args.embed_dim``, ``args.output_dim``, ``args.activation_final``
                and ``args.aggregate_final``. NOTE: sets
                ``args.factor.final_embed_dim`` in place before the final-layer
                args are derived from it.
        """
        super().__init__(args)
        self.fp = args.factor
        self.reduce_function = args.factor_net.reduce_function
        self.num_layers = args.factor_net.num_pair_layers
        self.embed_dim = args.embed_dim
        self.append_keys = args.factor_net.append_keys
        # When True, forward() averages the last query node with the key node.
        self.aggregate_passive = True
        layers = []

        # NOTE(review): ``append_zero_keys`` is never assigned in this class; read
        # it defensively so a base class may still provide it. Previously a missing
        # attribute raised AttributeError whenever ``append_keys`` was falsy.
        append_zero_keys = getattr(self, "append_zero_keys", False)
        if self.append_keys or append_zero_keys:
            kq_dim = args.factor.key_dim + args.factor.query_dim
        else:
            kq_dim = args.factor.query_dim

        # The GNN gets its own copy of args so its dims can be overridden freely.
        gnn_args = copy.deepcopy(args)
        gnn_args.object_dim = self.embed_dim if self.embed_dim > 0 else kq_dim
        gnn_args.output_dim = self.embed_dim if self.embed_dim > 0 else args.output_dim
        # With an embedding, the GNN is an intermediate layer, so it uses the hidden activation.
        gnn_args.activation_final = gnn_args.activation if self.embed_dim > 0 else args.activation_final
        self.gnn_args = gnn_args
        self.gnn_layer = GraphNetwork(gnn_args)
        layers.append(self.gnn_layer)

        args.factor.final_embed_dim = (
            self.embed_dim if self.embed_dim > 0
            else args.factor.key_dim + args.factor.query_dim
        )
        self.aggregate_final = args.aggregate_final
        if args.aggregate_final:  # does not work with a post-channel
            final_args = final_mlp_args(args)
            self.decode = MLPNetwork(final_args)
            layers.append(self.decode)
        elif self.embed_dim > 0:
            # Need a network to map from embed_dim back to the per-object output dim.
            final_args = final_conv_args(args)
            self.decode = ConvNetwork(final_args)
            layers.append(self.decode)

        # Plain list; submodules are registered through the attributes above.
        self.model = layers
        self.train()
        self.reset_network_parameters()

    @staticmethod
    def _pair_edge_index(bs, num_obj, mask, device):
        """Return the ``[2, E]`` long edge index of the batch graph, minus masked nodes.

        Args:
            bs: batch size.
            num_obj: number of query objects per sample.
            mask: per-object mask, written into the first ``mask.shape[2]`` node
                slots of each sample (example shape from the original code:
                ``[bs, 1, num_obj]``); entries equal to 0 drop that node's edges.
            device: device on which the index tensors are created.
        """
        group = num_obj + 1
        nodes = torch.arange(bs * group, device=device)
        slot = nodes % group  # position of each node inside its sample

        # All selections below keep ascending node order, so consecutive
        # entries line up sample by sample.
        key_nodes = nodes[slot == num_obj]        # one key node per sample
        query_nodes = nodes[slot < num_obj]       # num_obj query nodes per sample
        active_nodes = nodes[slot < num_obj - 1]  # queries except the last per sample

        # query -> key
        row_1 = query_nodes
        col_1 = key_nodes.repeat_interleave(num_obj)
        # key -> active queries
        row_2 = key_nodes.repeat_interleave(max(num_obj - 1, 0))
        col_2 = active_nodes

        row = torch.cat((row_1, row_2, col_2))
        col = torch.cat((col_1, col_2, col_2))  # last chunk: self loops
        edge_index = torch.stack([row, col], dim=0)

        # Expand the mask to node layout: the key node slot is always kept (1).
        # Index the flat expanded mask, NOT the raw mask, so that dropped indices
        # are node indices (b * (num_obj + 1) + j) rather than b * num_obj + j.
        mask_with_key = torch.ones((bs, 1, group), dtype=mask.dtype, device=device)
        mask_with_key[:, :, :mask.shape[2]] = mask
        dropped = torch.nonzero(mask_with_key.reshape(-1) == 0).reshape(-1)
        keep = ~torch.isin(edge_index, dropped).any(dim=0)
        return edge_index[:, keep]

    def forward(self, key, query, mask, ret_settings):
        """Run the GNN over the key/query graph and decode the node features.

        Args:
            key: key features, passed to ``GraphNetwork`` together with ``query``
                (shape not fixed by this method -- TODO confirm).
            query: query features with batch on dim 0 and objects on dim 1
                (example from the original code: ``[500, 5, 9]``).
            mask: per-object mask (example: ``[500, 1, 5]``); 0 removes the object.
            ret_settings: forwarded to ``return_values`` to select the outputs.

        Returns:
            Whatever ``return_values`` builds from ``ret_settings``, the final
            output ``x``, ``(key, query)``, the pre-decode embeddings and the
            reduction (``None`` unless ``aggregate_final``).
        """
        bs, num_obj = query.shape[0], query.shape[1]
        edge_index = self._pair_edge_index(bs, num_obj, mask, key.device)

        x = self.gnn_layer((query, key), edge_index)
        x = x.view(bs, num_obj + 1, -1).transpose(2, 1)  # [bs, feat, num_obj + 1]
        if self.aggregate_passive:
            # Merge the last query node with the key node -> num_obj columns.
            merged = x[:, :, num_obj - 1:].mean(dim=2, keepdim=True)
            x = torch.cat([x[:, :, :num_obj - 1], merged], dim=2)

        embeddings, reduction = x, None
        if self.aggregate_final:
            x = reduce_function(self.reduce_function, x)
            x = x.view(x.shape[0], -1)
            reduction = x
            if self.embed_dim > 0:
                x = self.decode(x)
        else:
            if self.embed_dim > 0:
                x = self.decode(x)
            x = x.transpose(2, 1)
            x = x.reshape(x.shape[0], -1)
        return return_values(ret_settings, x, (key, query), embeddings, reduction)